{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/gogamid/avalanche/blob/master/notebooks/how-tos/dataloading_buffers_replay.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": true,
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "3ETjeh2Cf6-L"
      },
      "source": [
        "---\n",
        "description: How to implement replay and data loading\n",
        "---\n",
        "# Dataloading, Memory Buffers, and Replay\n",
        "\n",
        "Avalanche provides several components that help you to balance data loading and implement rehearsal strategies.\n",
        "\n",
        "**Dataloaders** are used to provide balancing between groups (e.g. tasks/classes/experiences). This is especially useful when you have unbalanced data.\n",
        "\n",
        "**Buffers** are used to store data from the previous experiences. They are dynamic datasets with a fixed maximum size, and they can be updated with new data continuously.\n",
        "\n",
        "Finally, **Replay** strategies implement rehearsal by using Avalanche's plugin system. Most rehearsal strategies use a custom dataloader to balance the buffer with the current experience and a buffer that is updated for each experience.\n",
        "\n",
        "First, let's install Avalanche. You can skip this step if you have installed it already."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "70_kgVQCf6-M",
        "outputId": "d1b04297-bcad-4aa3-8fd6-ba4937bb97e0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Requirement already satisfied: avalanche-lib in /usr/local/lib/python3.10/dist-packages (0.3.1)\n"
          ]
        }
      ],
      "source": [
        "!pip install avalanche-lib"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "EUDkdUvmf6-M"
      },
      "source": [
        "## Dataloaders\n",
        "Avalanche dataloaders are simple iterators, located under `avalanche.benchmarks.utils.data_loader`. Their interface is equivalent to PyTorch's dataloaders. For example, `GroupBalancedDataLoader` takes a sequence of datasets and iterates over them by providing balanced mini-batches, where the number of samples is split equally among groups. Internally, it instantiates a `DataLoader` for each separate group. More specialized dataloaders also exist, such as `TaskBalancedDataLoader`.\n",
        "\n",
        "All the dataloaders accept keyword arguments (`**kwargs`) that are passed directly to the dataloaders for each group."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "i7aukwdYf6-N",
        "outputId": "633e4cca-87f7-4c4b-f9cf-1ce27d89e4bd",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[0, 1, 2, 3, 4]\n"
          ]
        }
      ],
      "source": [
        "from avalanche.benchmarks import SplitMNIST\n",
        "from avalanche.benchmarks.utils.data_loader import GroupBalancedDataLoader\n",
        "benchmark = SplitMNIST(5, return_task_id=True)\n",
        "\n",
        "dl = GroupBalancedDataLoader([exp.dataset for exp in benchmark.train_stream], batch_size=5)\n",
        "for x, y, t in dl:\n",
        "    print(t.tolist())\n",
        "    break\n"
      ]
    },
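    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the balancing idea concrete, the next cell is a minimal pure-Python sketch of group-balanced batching. It is not Avalanche's implementation: the helper `balanced_batches` is hypothetical, the batch size is assumed to be split evenly across groups, and iteration simply stops when the smallest group runs out, which may differ from how `GroupBalancedDataLoader` handles uneven groups."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from itertools import islice\n",
        "\n",
        "def balanced_batches(groups, batch_size):\n",
        "    \"\"\"Yield mini-batches with an equal share of samples from each group.\n",
        "\n",
        "    Illustrative only: a hypothetical helper, not part of Avalanche.\n",
        "    \"\"\"\n",
        "    per_group = batch_size // len(groups)  # equal share for every group\n",
        "    if per_group == 0:\n",
        "        raise ValueError(\"batch_size must be at least the number of groups\")\n",
        "    iterators = [iter(g) for g in groups]\n",
        "    while True:\n",
        "        batch = []\n",
        "        for it in iterators:\n",
        "            chunk = list(islice(it, per_group))\n",
        "            if len(chunk) < per_group:\n",
        "                return  # simplification: stop once any group is exhausted\n",
        "            batch.extend(chunk)\n",
        "        yield batch\n",
        "\n",
        "# Five toy groups of (sample, task_id) pairs, mimicking five tasks.\n",
        "toy_groups = [[(f\"x{t}_{i}\", t) for i in range(4)] for t in range(5)]\n",
        "first = next(balanced_batches(toy_groups, batch_size=5))\n",
        "print([t for _, t in first])  # one sample per task: [0, 1, 2, 3, 4]"
      ]
    },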
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "3xEHbg9Sf6-N"
      },
      "source": [
        "## Memory Buffers\n",
        "Memory buffers store data up to a maximum capacity, and they implement policies to select which data to store and which to remove when the buffer is full. They are available in the module `avalanche.training.storage_policy`. The base class is `ExemplarsBuffer`, which implements two methods:\n",
        "- `update(strategy)` - given the strategy's state, it updates the buffer (using the data in `strategy.experience.dataset`).\n",
        "- `resize(strategy, new_size)` - updates the maximum size and updates the buffer accordingly.\n",
        "\n",
        "The data can be accessed using the attribute `buffer`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "XU8oKhebf6-N",
        "outputId": "478f7e6d-5a36-4bb2-b402-a6c780eafaf2",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Max buffer size: 30, current size: 0\n"
          ]
        }
      ],
      "source": [
        "from avalanche.training.storage_policy import ReservoirSamplingBuffer\n",
        "from types import SimpleNamespace\n",
        "\n",
        "benchmark = SplitMNIST(5, return_task_id=False)\n",
        "storage_p = ReservoirSamplingBuffer(max_size=30)\n",
        "\n",
        "print(f\"Max buffer size: {storage_p.max_size}, current size: {len(storage_p.buffer)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "bps6v20lf6-N"
      },
      "source": [
        "At first, the buffer is empty. We can update it with data from a new experience.\n",
        "\n",
        "Notice that we use a `SimpleNamespace` because we want to use the buffer standalone, without instantiating an Avalanche strategy. Reservoir sampling requires only the `experience` from the strategy's state."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "dALuuRb6f6-N",
        "outputId": "d330f37f-5526-4a17-9df2-8339b87709f4",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Max buffer size: 30, current size: 30\n",
            "class targets: {8, 7}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {8, 0, 5, 7}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {0, 1, 2, 5, 7, 8}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {0, 1, 2, 3, 5, 7, 8, 9}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {0, 1, 2, 3, 5, 6, 7, 8, 9}\n",
            "\n"
          ]
        }
      ],
      "source": [
        "for i in range(5):\n",
        "    strategy_state = SimpleNamespace(experience=benchmark.train_stream[i])\n",
        "    storage_p.update(strategy_state)\n",
        "    print(f\"Max buffer size: {storage_p.max_size}, current size: {len(storage_p.buffer)}\")\n",
        "    print(f\"class targets: {storage_p.buffer.targets.uniques}\\n\")"
      ]
    },
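    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see why the buffer stays at its maximum size while gradually mixing in newer data, the next cell sketches classic reservoir sampling (Algorithm R) in pure Python. This is an illustration under stated assumptions, not Avalanche's code: `reservoir_update` is a hypothetical helper, the toy \"experiences\" are plain lists, and `ReservoirSamplingBuffer` may be implemented differently internally."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import random\n",
        "\n",
        "def reservoir_update(buffer, stream, max_size, n_seen, rng=random):\n",
        "    \"\"\"Classic reservoir sampling (Algorithm R) over a stream of items.\n",
        "\n",
        "    Illustrative only: a hypothetical helper, not part of Avalanche.\n",
        "    Modifies `buffer` in place and returns the updated count of items seen.\n",
        "    \"\"\"\n",
        "    for item in stream:\n",
        "        n_seen += 1\n",
        "        if len(buffer) < max_size:\n",
        "            buffer.append(item)        # buffer not full yet: always keep\n",
        "        else:\n",
        "            j = rng.randrange(n_seen)  # uniform index in [0, n_seen)\n",
        "            if j < max_size:\n",
        "                buffer[j] = item       # overwrite a stored item\n",
        "    return n_seen\n",
        "\n",
        "toy_buffer, n_seen = [], 0\n",
        "for exp_id in range(3):\n",
        "    # Each toy \"experience\" contributes 100 items tagged with its id.\n",
        "    stream = [(exp_id, i) for i in range(100)]\n",
        "    n_seen = reservoir_update(toy_buffer, stream, max_size=30, n_seen=n_seen)\n",
        "    seen_exps = sorted({e for e, _ in toy_buffer})\n",
        "    print(f\"after experience {exp_id}: size={len(toy_buffer)}, experiences in buffer={seen_exps}\")"
      ]
    },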
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "sSLusXfqf6-N"
      },
      "source": [
        "Notice that after each update some samples are replaced with new data. Reservoir sampling selects these samples randomly.\n",
        "\n",
        "Avalanche offers many more storage policies. For example, `ParametricBuffer` is a buffer split into several groups according to the `groupby` parameter (`None`, `'class'`, `'task'`, `'experience'`), with each group filled according to an optional `ExemplarsSelectionStrategy` (random selection is the default choice)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "LSK9mezrf6-N",
        "outputId": "769d7d46-ad4a-46be-f84c-03a1b5d1a1f7",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Max buffer size: 30, current size: 0\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {8, 7}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {8, 0, 5, 7}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {0, 1, 2, 5, 7, 8}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {0, 1, 2, 3, 5, 7, 8, 9}\n",
            "\n",
            "Max buffer size: 30, current size: 30\n",
            "class targets: {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}\n",
            "\n"
          ]
        }
      ],
      "source": [
        "from avalanche.training.storage_policy import ParametricBuffer, RandomExemplarsSelectionStrategy\n",
        "storage_p = ParametricBuffer(\n",
        "    max_size=30,\n",
        "    groupby='class',\n",
        "    selection_strategy=RandomExemplarsSelectionStrategy()\n",
        ")\n",
        "\n",
        "print(f\"Max buffer size: {storage_p.max_size}, current size: {len(storage_p.buffer)}\")\n",
        "for i in range(5):\n",
        "    strategy_state = SimpleNamespace(experience=benchmark.train_stream[i])\n",
        "    storage_p.update(strategy_state)\n",
        "    print(f\"Max buffer size: {storage_p.max_size}, current size: {len(storage_p.buffer)}\")\n",
        "    print(f\"class targets: {storage_p.buffer.targets.uniques}\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "fVLRip-Rf6-N"
      },
      "source": [
        "The advantage of using grouping buffers is that you get a balanced rehearsal buffer. You can even access the groups separately with the `buffer_groups` attribute. Combined with balanced dataloaders, you can ensure that the mini-batches stay balanced during training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "-VPZXPoEf6-O",
        "outputId": "440cb869-6796-4d99-800d-16a1a1bedc1e",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "(group 7) -> size 3\n",
            "(group 8) -> size 3\n",
            "(group 5) -> size 3\n",
            "(group 0) -> size 3\n",
            "(group 1) -> size 3\n",
            "(group 2) -> size 3\n",
            "(group 9) -> size 3\n",
            "(group 3) -> size 3\n",
            "(group 4) -> size 3\n",
            "(group 6) -> size 3\n"
          ]
        }
      ],
      "source": [
        "for k, v in storage_p.buffer_groups.items():\n",
        "    print(f\"(group {k}) -> size {len(v.buffer)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "wnOIVcmIf6-O",
        "outputId": "f1c043fa-23e0-423a-cecd-fdadd264a604",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "[7, 7, 7, 8, 8, 8, 5, 5, 5, 0, 0, 0, 1, 1, 1, 2, 2, 2, 9, 9, 9, 3, 3, 3, 4, 4, 4, 6, 6, 6]\n"
          ]
        }
      ],
      "source": [
        "datas = [v.buffer for v in storage_p.buffer_groups.values()]\n",
        "dl = GroupBalancedDataLoader(datas)\n",
        "\n",
        "for x, y, t in dl:\n",
        "    print(y.tolist())\n",
        "    break"
      ]
    },
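    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `resize(strategy, new_size)` method listed in the buffer interface can also be tried on a standalone buffer. The next cell is a sketch that assumes `resize` behaves as described in this tutorial, shrinking the stored data when the new capacity is smaller. The `demo_*` names are introduced only for illustration, and a fresh buffer is used so that `storage_p` is left untouched."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from types import SimpleNamespace\n",
        "\n",
        "from avalanche.benchmarks import SplitMNIST\n",
        "from avalanche.training.storage_policy import ReservoirSamplingBuffer\n",
        "\n",
        "# Fresh benchmark and buffer, independent of the objects used in other cells.\n",
        "demo_benchmark = SplitMNIST(5, return_task_id=False)\n",
        "demo_buffer = ReservoirSamplingBuffer(max_size=30)\n",
        "\n",
        "# As before, a SimpleNamespace stands in for the strategy's state.\n",
        "demo_state = SimpleNamespace(experience=demo_benchmark.train_stream[0])\n",
        "demo_buffer.update(demo_state)\n",
        "print(f\"Before resize -> max size: {demo_buffer.max_size}, current size: {len(demo_buffer.buffer)}\")\n",
        "\n",
        "# Shrink the capacity; the buffer is expected to drop samples to fit.\n",
        "demo_buffer.resize(demo_state, 10)\n",
        "print(f\"After resize  -> max size: {demo_buffer.max_size}, current size: {len(demo_buffer.buffer)}\")"
      ]
    },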
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "efw1KMwMf6-O"
      },
      "source": [
        "## Replay Plugins\n",
        "\n",
        "Avalanche's strategy plugins can be used to update the rehearsal buffer and set the dataloader. This makes it easy to implement replay strategies:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "5h1wobidf6-O"
      },
      "outputs": [],
      "source": [
        "from avalanche.benchmarks.utils.data_loader import ReplayDataLoader\n",
        "from avalanche.training.plugins import SupervisedPlugin\n",
        "\n",
        "class CustomReplay(SupervisedPlugin):\n",
        "    def __init__(self, storage_policy):\n",
        "        super().__init__()\n",
        "        self.storage_policy = storage_policy\n",
        "\n",
        "    def before_training_exp(self, strategy,\n",
        "                            num_workers: int = 0, shuffle: bool = True,\n",
        "                            **kwargs):\n",
        "        \"\"\"Set the dataloader before training on each experience.\"\"\"\n",
        "        if len(self.storage_policy.buffer) == 0:\n",
        "            # First experience: the buffer is empty, so there is no need\n",
        "            # to change the dataloader.\n",
        "            return\n",
        "\n",
        "        # The replay dataloader samples mini-batches from the memory and\n",
        "        # the current data separately and combines them together.\n",
        "        print(\"Override the dataloader.\")\n",
        "        strategy.dataloader = ReplayDataLoader(\n",
        "            strategy.adapted_dataset,\n",
        "            self.storage_policy.buffer,\n",
        "            oversample_small_tasks=True,\n",
        "            num_workers=num_workers,\n",
        "            batch_size=strategy.train_mb_size,\n",
        "            shuffle=shuffle)\n",
        "\n",
        "    def after_training_exp(self, strategy, **kwargs):\n",
        "        \"\"\"Update the buffer after training on the experience.\n",
        "\n",
        "        You can use a different callback to update the buffer at a different point.\n",
        "        \"\"\"\n",
        "        print(\"Buffer update.\")\n",
        "        self.storage_policy.update(strategy, **kwargs)\n"
      ]
    },
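    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The plugin above relies on the replay dataloader to sample mini-batches from the memory and the current data separately and then combine them. The next cell is a simplified pure-Python sketch of that idea, not the actual `ReplayDataLoader`: `replay_batches` is a hypothetical helper, both sources are plain lists, the same batch size is assumed for both, and the smaller memory is simply restarted when it runs out."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from itertools import cycle\n",
        "\n",
        "def replay_batches(current, memory, batch_size):\n",
        "    \"\"\"Pair each mini-batch of current data with one mini-batch of memory.\n",
        "\n",
        "    Illustrative only: a hypothetical helper, not part of Avalanche.\n",
        "    Assumes `memory` is non-empty, like the plugin's early return above.\n",
        "    \"\"\"\n",
        "    def chunks(data):\n",
        "        return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]\n",
        "\n",
        "    # cycle() restarts the smaller memory once it is exhausted (oversampling).\n",
        "    memory_batches = cycle(chunks(memory))\n",
        "    for current_mb in chunks(current):\n",
        "        yield current_mb + next(memory_batches)\n",
        "\n",
        "toy_current = [(\"new\", i) for i in range(8)]\n",
        "toy_memory = [(\"old\", i) for i in range(3)]\n",
        "for mb in replay_batches(toy_current, toy_memory, batch_size=4):\n",
        "    print(mb)"
      ]
    },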
    {
      "cell_type": "markdown",
      "metadata": {
        "pycharm": {
          "name": "#%% md\n"
        },
        "id": "QY8dLk32f6-O"
      },
      "source": [
        "And of course, we can use the plugin to train our continual model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "pycharm": {
          "name": "#%%\n"
        },
        "id": "-BUUbK93f6-O",
        "outputId": "ef971fbb-8870-4db6-e7f1-ee24ae48a8a0",
        "colab": {
          "base_uri": "https://localhost:8080/"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Starting experiment...\n",
            "Start of experience  0\n",
            "-- >> Start of training phase << --\n",
            "100%|██████████| 122/122 [00:03<00:00, 33.00it/s]\n",
            "Epoch 0 ended.\n",
            "Buffer update.\n",
            "-- >> End of training phase << --\n",
            "Training completed\n",
            "Computing accuracy on the whole test set\n",
            "-- >> Start of eval phase << --\n",
            "-- Starting eval on experience 0 (Task 0) from test stream --\n",
            "100%|██████████| 21/21 [00:00<00:00, 45.32it/s]\n",
            "> Eval on experience 0 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9915\n",
            "-- Starting eval on experience 1 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 41.48it/s]\n",
            "> Eval on experience 1 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.0000\n",
            "-- Starting eval on experience 2 (Task 0) from test stream --\n",
            "100%|██████████| 22/22 [00:00<00:00, 46.76it/s]\n",
            "> Eval on experience 2 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0000\n",
            "-- Starting eval on experience 3 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 44.33it/s]\n",
            "> Eval on experience 3 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0000\n",
            "-- Starting eval on experience 4 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 42.56it/s]\n",
            "> Eval on experience 4 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
            "-- >> End of eval phase << --\n",
            "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.1993\n",
            "Start of experience  1\n",
            "-- >> Start of training phase << --\n",
            "Override the dataloader.\n",
            "100%|██████████| 119/119 [00:05<00:00, 22.27it/s]\n",
            "Epoch 0 ended.\n",
            "Buffer update.\n",
            "-- >> End of training phase << --\n",
            "Training completed\n",
            "Computing accuracy on the whole test set\n",
            "-- >> Start of eval phase << --\n",
            "-- Starting eval on experience 0 (Task 0) from test stream --\n",
            "100%|██████████| 21/21 [00:00<00:00, 43.59it/s]\n",
            "> Eval on experience 0 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9438\n",
            "-- Starting eval on experience 1 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 43.12it/s]\n",
            "> Eval on experience 1 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.9859\n",
            "-- Starting eval on experience 2 (Task 0) from test stream --\n",
            "100%|██████████| 22/22 [00:00<00:00, 42.94it/s]\n",
            "> Eval on experience 2 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.0000\n",
            "-- Starting eval on experience 3 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 45.31it/s]\n",
            "> Eval on experience 3 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0000\n",
            "-- Starting eval on experience 4 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 41.80it/s]\n",
            "> Eval on experience 4 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
            "-- >> End of eval phase << --\n",
            "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.3859\n",
            "Start of experience  2\n",
            "-- >> Start of training phase << --\n",
            "Override the dataloader.\n",
            "100%|██████████| 127/127 [00:04<00:00, 27.83it/s]\n",
            "Epoch 0 ended.\n",
            "Buffer update.\n",
            "-- >> End of training phase << --\n",
            "Training completed\n",
            "Computing accuracy on the whole test set\n",
            "-- >> Start of eval phase << --\n",
            "-- Starting eval on experience 0 (Task 0) from test stream --\n",
            "100%|██████████| 21/21 [00:00<00:00, 44.07it/s]\n",
            "> Eval on experience 0 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.9338\n",
            "-- Starting eval on experience 1 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 41.27it/s]\n",
            "> Eval on experience 1 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.9266\n",
            "-- Starting eval on experience 2 (Task 0) from test stream --\n",
            "100%|██████████| 22/22 [00:00<00:00, 42.52it/s]\n",
            "> Eval on experience 2 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.9948\n",
            "-- Starting eval on experience 3 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 44.70it/s]\n",
            "> Eval on experience 3 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.0000\n",
            "-- Starting eval on experience 4 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 42.38it/s]\n",
            "> Eval on experience 4 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
            "-- >> End of eval phase << --\n",
            "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.5825\n",
            "Start of experience  3\n",
            "-- >> Start of training phase << --\n",
            "Override the dataloader.\n",
            "100%|██████████| 116/116 [00:05<00:00, 21.70it/s]\n",
            "Epoch 0 ended.\n",
            "Buffer update.\n",
            "-- >> End of training phase << --\n",
            "Training completed\n",
            "Computing accuracy on the whole test set\n",
            "-- >> Start of eval phase << --\n",
            "-- Starting eval on experience 0 (Task 0) from test stream --\n",
            "100%|██████████| 21/21 [00:00<00:00, 44.35it/s]\n",
            "> Eval on experience 0 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.8672\n",
            "-- Starting eval on experience 1 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 41.18it/s]\n",
            "> Eval on experience 1 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.8447\n",
            "-- Starting eval on experience 2 (Task 0) from test stream --\n",
            "100%|██████████| 22/22 [00:00<00:00, 39.84it/s]\n",
            "> Eval on experience 2 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.9433\n",
            "-- Starting eval on experience 3 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 41.79it/s]\n",
            "> Eval on experience 3 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.9569\n",
            "-- Starting eval on experience 4 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 42.67it/s]\n",
            "> Eval on experience 4 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.0000\n",
            "-- >> End of eval phase << --\n",
            "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.7239\n",
            "Start of experience  4\n",
            "-- >> Start of training phase << --\n",
            "Override the dataloader.\n",
            "100%|██████████| 118/118 [00:04<00:00, 27.12it/s]\n",
            "Epoch 0 ended.\n",
            "Buffer update.\n",
            "-- >> End of training phase << --\n",
            "Training completed\n",
            "Computing accuracy on the whole test set\n",
            "-- >> Start of eval phase << --\n",
            "-- Starting eval on experience 0 (Task 0) from test stream --\n",
            "100%|██████████| 21/21 [00:00<00:00, 45.44it/s]\n",
            "> Eval on experience 0 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp000 = 0.7438\n",
            "-- Starting eval on experience 1 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 40.99it/s]\n",
            "> Eval on experience 1 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp001 = 0.8739\n",
            "-- Starting eval on experience 2 (Task 0) from test stream --\n",
            "100%|██████████| 22/22 [00:00<00:00, 44.47it/s]\n",
            "> Eval on experience 2 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp002 = 0.9603\n",
            "-- Starting eval on experience 3 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 32.68it/s]\n",
            "> Eval on experience 3 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp003 = 0.7171\n",
            "-- Starting eval on experience 4 (Task 0) from test stream --\n",
            "100%|██████████| 20/20 [00:00<00:00, 31.98it/s]\n",
            "> Eval on experience 4 (Task 0) from test stream ended.\n",
            "\tTop1_Acc_Exp/eval_phase/test_stream/Task000/Exp004 = 0.9622\n",
            "-- >> End of eval phase << --\n",
            "\tTop1_Acc_Stream/eval_phase/test_stream/Task000 = 0.8537\n"
          ]
        }
      ],
      "source": [
        "from torch.nn import CrossEntropyLoss\n",
        "from avalanche.training import Naive\n",
        "from avalanche.evaluation.metrics import accuracy_metrics\n",
        "from avalanche.training.plugins import EvaluationPlugin\n",
        "from avalanche.logging import InteractiveLogger\n",
        "from avalanche.models import SimpleMLP\n",
        "import torch\n",
        "\n",
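        "# Split MNIST into 5 experiences\n",
        "scenario = SplitMNIST(5)\n",
        "model = SimpleMLP(num_classes=scenario.n_classes)\n",
        "# Class-balanced buffer holding at most 500 randomly selected samples\n",
        "storage_p = ParametricBuffer(\n",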
        "    max_size=500,\n",
        "    groupby='class',\n",
        "    selection_strategy=RandomExemplarsSelectionStrategy()\n",
        ")\n",
        "\n",
        "# Choose the metrics to track and the logger that reports them\n",
        "interactive_logger = InteractiveLogger()\n",
        "\n",
        "eval_plugin = EvaluationPlugin(\n",
        "    accuracy_metrics(experience=True, stream=True),\n",
        "    loggers=[interactive_logger])\n",
        "\n",
        "# Create the strategy instance: Naive fine-tuning plus the CustomReplay plugin\n",
        "cl_strategy = Naive(model, torch.optim.Adam(model.parameters(), lr=0.001),\n",
        "                    CrossEntropyLoss(),\n",
        "                    train_mb_size=100, train_epochs=1, eval_mb_size=100,\n",
        "                    plugins=[CustomReplay(storage_p)],\n",
        "                    evaluator=eval_plugin\n",
        "                    )\n",
        "\n",
        "# Training loop: train on each experience, then evaluate on the whole test stream\n",
        "print('Starting experiment...')\n",
        "results = []\n",
        "for experience in scenario.train_stream:\n",
        "    print(\"Start of experience \", experience.current_experience)\n",
        "    cl_strategy.train(experience)\n",
        "    print('Training completed')\n",
        "\n",
        "    print('Computing accuracy on the whole test set')\n",
        "    results.append(cl_strategy.eval(scenario.test_stream))"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.9"
    },
    "colab": {
      "provenance": [],
      "include_colab_link": true
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}